{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8a6c76c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer\n",
    "import safetensors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "001d3d03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer of the base checkpoint, plus a streamer built on it that is later\n",
    "# handed to model.generate() through the `streamer` argument.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    'Qwen/Qwen2.5-72B-Instruct',\n",
    ")\n",
    "streamer = TextStreamer(tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e02d61fd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2025-04-30 15:55:03,893] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/bin/ld: cannot find -laio: No such file or directory\n",
      "collect2: error: ld returned 1 exit status\n",
      "/usr/bin/ld: cannot find -lcufile: No such file or directory\n",
      "collect2: error: ld returned 1 exit status\n",
      "Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "66632c07204f4e66b99570b460ac8364",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/37 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Base model weights are loaded as bfloat16; device placement is delegated to\n",
    "# the loader via device_map=\"auto\".\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    'Qwen/Qwen2.5-72B-Instruct',\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"auto\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2aeb4f1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name -> tensor mapping of the loaded model; the merge cell looks base weights\n",
    "# up by name in this mapping and updates them in place.\n",
    "# NOTE(review): relies on state_dict() handing back tensors that share storage\n",
    "# with the live parameters (see torch.nn.Module.state_dict docs) -- confirm.\n",
    "state_dict = model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4e5bea15",
   "metadata": {},
   "outputs": [],
   "source": [
    "from safetensors import safe_open\n",
    "\n",
    "folder = 'lora-embedding-128-qwen2.5-72b-malaysian-8k/checkpoint-1700'\n",
    "\n",
    "# The handle is intentionally left open: the merge cell reads tensors from it\n",
    "# one by one with f.get_tensor(...).\n",
    "f = safe_open(f\"{folder}/adapter_model.safetensors\", framework=\"pt\", device='cpu')\n",
    "\n",
    "# Every adapter tensor name contains '.lora'; keep only the part in front of it\n",
    "# so each adapted module appears once, in sorted order.\n",
    "keys = sorted({name.split('.lora')[0] for name in f.keys() if '.lora' in name})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "68b1747d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████| 562/562 [00:00<00:00, 576.70it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# Scaling factor applied to the low-rank product when it is added to a weight.\n",
    "# NOTE(review): presumably lora_alpha / r of this adapter (see its\n",
    "# adapter_config.json) -- confirm before reusing this cell with another adapter.\n",
    "LORA_SCALING = 2\n",
    "\n",
    "# The merge is an additive in-place update, so executing this cell a second time\n",
    "# would add the adapter delta again. Refuse to run twice on the same model object.\n",
    "# The flag is set before the loop so that a merge interrupted half-way is not\n",
    "# silently re-applied on top of the weights that were already updated.\n",
    "if getattr(model, '_lora_merged', False):\n",
    "    raise RuntimeError(\n",
    "        'LoRA adapter was already merged into `model`; '\n",
    "        'reload the base model before merging again.'\n",
    "    )\n",
    "model._lora_merged = True\n",
    "\n",
    "for k in tqdm(keys):\n",
    "    is_embedding = 'embed_tokens' in k\n",
    "\n",
    "    # Adapter keys carry a 'base_model.model.' prefix that the base state dict lacks.\n",
    "    k_ori = k.replace('base_model.model.', '') + '.weight'\n",
    "    if is_embedding:\n",
    "        name_A = k + '.lora_embedding_A'\n",
    "        name_B = k + '.lora_embedding_B'\n",
    "    else:\n",
    "        name_A = k + '.lora_A.weight'\n",
    "        name_B = k + '.lora_B.weight'\n",
    "\n",
    "    W = state_dict[k_ori]\n",
    "    if not is_embedding:\n",
    "        # Work on the transposed view; the in-place add below still writes into\n",
    "        # the storage of the original weight.\n",
    "        W = W.t()\n",
    "\n",
    "    # Bring the adapter tensors to the device *and* dtype of the target weight:\n",
    "    # the base model is bfloat16, and addmm_ cannot mix that with an adapter\n",
    "    # stored in another dtype.\n",
    "    A = f.get_tensor(name_A).to(device=W.device, dtype=W.dtype)\n",
    "    B = f.get_tensor(name_B).to(device=W.device, dtype=W.dtype)\n",
    "    with torch.no_grad():\n",
    "        # W += LORA_SCALING * (A^T @ B^T)\n",
    "        W.addmm_(A.t(), B.t(), alpha=LORA_SCALING)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eef0786",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|im_start|>system\n",
      "You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n",
      "<|im_start|>user\n",
      "butoh hang tu apa dalam bahsa kedah<|im_end|>\n",
      "<|im_start|>assistant\n",
      "Ia bermaksud 1. sesuatu yg menutup lubang (pintu dll): ~ "
     ]
    }
   ],
   "source": [
    "messages = [\n",
    "    {'role': 'user', 'content': 'butoh hang tu apa dalam bahsa kedah'}\n",
    "]\n",
    "\n",
    "# add_generation_prompt=True makes the chat template end the prompt with the\n",
    "# assistant header, so generation starts directly at the assistant's reply\n",
    "# instead of relying on the model to emit that header on its own.\n",
    "# The prompt is moved to the model's own device rather than a hard-coded 'cuda'.\n",
    "inputs = tokenizer.apply_chat_template(\n",
    "    messages,\n",
    "    add_generation_prompt=True,\n",
    "    return_tensors='pt',\n",
    ").to(model.device)\n",
    "\n",
    "# Sampling configuration; output is stochastic because do_sample=True and no\n",
    "# seed is fixed in this notebook.\n",
    "generate_kwargs = dict(\n",
    "    input_ids=inputs,\n",
    "    max_new_tokens=512,\n",
    "    top_p=0.95,\n",
    "    top_k=50,\n",
    "    temperature=0.9,\n",
    "    do_sample=True,\n",
    "    repetition_penalty=1.05,\n",
    "    streamer=streamer\n",
    ")\n",
    "generation_output = model.generate(**generate_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd4c0ad7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
